#include <stdio.h>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <chrono>
#include <climits>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "net/tcp_client.hpp"
#include "net/io_service_pool.hpp"
#include "base/logger.h"

// The unqualified networking names used below (TcpClient, TcpConnectionPtr,
// Buffer, io_service_pool) presumably come from this namespace — confirm.
using namespace easyasio::net;
// Forward declaration: Session's constructor and its owner_ member refer to
// ClientMng* before ClientMng is defined further down.
class ClientMng;

// Test-sequence parameters. These are typed constexpr values rather than
// #defines. Every use in this file is an int context, so the names are unchanged.

// First payload length (and sequence index) each Session sends and expects back.
constexpr int START_INDEX = 1000;
// Last payload length sent and expected. 2 + MAX_INDEX must fit the uint16_t
// frame-length header, so this must stay <= 65533.
constexpr int MAX_INDEX = 6000;
// Delay in milliseconds between consecutive sends of one Session.
constexpr int INTERNAL_MSEC = 1;

class Session 
{
public:
    Session(asio::io_service& loop, const std::string& remote_addr, int remote_port, ClientMng* owner)
        : client_(std::make_shared<TcpClient>(loop, remote_addr, remote_port)), 
          owner_(owner), 
          loop_timer_(loop)
    {
        next_send_index_ = START_INDEX;
        next_recv_index_ = START_INDEX;
        client_->setConnectionCallback(
            std::bind(&Session::onConnection, this, std::placeholders::_1));
        client_->setMessageCallback(
            std::bind(&Session::onMessage, this, std::placeholders::_1, std::placeholders::_2));
    }

    void start() { client_->asynConnect(); }
    void stop() 
    { 
        std::cout << next_recv_index_ << std::endl;
        client_->stop(); 
    }

private:
    void sendMsg();
    void onConnection(const TcpConnectionPtr& conn);
    void onDisconnect(const TcpConnectionPtr& conn);
    void handleTimter()
    {
        sendMsg();
    }

    void setTimer(int milliseconds)
	{
        loop_timer_.expires_from_now(std::chrono::milliseconds(milliseconds));
	    loop_timer_.async_wait(std::bind(&Session::handleTimter, this));
	}

    void onMessage(const TcpConnectionPtr& conn, Buffer* buf);

private:
    std::shared_ptr<TcpClient> client_;
    ClientMng* owner_;

    asio::basic_waitable_timer<std::chrono::system_clock> loop_timer_;
    int next_send_index_;
    int next_recv_index_;
    TcpConnectionPtr conn_;
};

class ClientMng
{
 public:
  ClientMng(const std::string& remote_addr, int remote_port, int sessionCount, int timeout, int threadCount)
    : threadPool_(threadCount, true),
      loop_timer_(threadPool_.get_io_service()),
      sessionCount_(sessionCount),
      timeout_(timeout)
{
    complete_num_ = 0;
    numConnected_ = 0;
    loop_timer_.expires_from_now(std::chrono::seconds(timeout));
    loop_timer_.async_wait(std::bind(&ClientMng::handleTimeout, this));

    for (int i = 0; i < sessionCount; ++i)
    {
        char buf[32];
        #ifdef WIN32
        sprintf_s(buf, sizeof buf, "C%05d", i);
        #else
        snprintf(buf, sizeof buf, "C%05d", i);
        #endif

        std::shared_ptr<Session> session = std::make_shared<Session>(
            threadPool_.get_io_service(), remote_addr, remote_port, this);
        session->start();
        sessions_.push_back(session);
    }
}

void onConnect()
{
    ++numConnected_;
    if (numConnected_ == sessionCount_)
    {
        LOG_WARN("all connected");
    }
}

void run()
{
    threadPool_.run();
}

void onDisconnect()
{
    --numConnected_;
    if (numConnected_ == 0)
    {
        LOG_WARN("all disconnected");
        threadPool_.stop();
    }
}

void increaseComplete()
{
    complete_num_++;
    std::cout << "complete_num_ = " << complete_num_ << std::endl;
    if (complete_num_ == sessionCount_)
        std::for_each(sessions_.begin(), sessions_.end(), std::mem_fn(&Session::stop));
}

private:
    void handleTimeout()
    {
        LOG_WARN("stop");
        std::for_each(sessions_.begin(), sessions_.end(), std::mem_fn(&Session::stop));
    }

private:
    io_service_pool threadPool_;
    asio::basic_waitable_timer<std::chrono::system_clock> loop_timer_;
    int sessionCount_;
    int timeout_;

    std::vector<std::shared_ptr<Session>> sessions_;
    std::atomic_int numConnected_;
    std::atomic_int complete_num_;
    std::shared_ptr<std::thread> iopool_thread_;
};

void Session::onDisconnect(const TcpConnectionPtr& conn)
{
    owner_->onDisconnect();
}

/// Connection-state callback installed in the Session constructor.
///
/// On connect, the steps run in this order:
/// 1. Store the connection.
/// 2. Disable Nagle.
/// 3. Install the close callback.
/// 4. Report to the owner.
/// 5. Arm the send timer, last, so that sendMsg() can never run before
///    conn_ has been assigned.
///
/// For any other state change, a disconnect is reported to the owner.
void Session::onConnection(const TcpConnectionPtr& conn)
{
    if (!conn->connected())
    {
        // NOTE(review): if TcpClient also invokes the close callback for the
        // same connection, the owner may see this disconnect twice — confirm.
        owner_->onDisconnect();
        return;
    }

    conn_ = conn;
    conn->setTcpNoDelay(true);
    conn->setCloseCallback(std::bind(&Session::onDisconnect, this, std::placeholders::_1));
    owner_->onConnect();
    setTimer(INTERNAL_MSEC);
}

/// Sends the frame for next_send_index_. The send timer is re-armed until the
/// frame for MAX_INDEX has been sent.
///
/// Frame layout:
/// - a uint16_t total length (header included, host byte order);
/// - followed by next_send_index_ payload bytes, each equal to
///   'a' + next_send_index_ % 26.
void Session::sendMsg()
{
    static_assert(MAX_INDEX + 2 <= UINT16_MAX, "frame length must fit the 16-bit header");

    if (!conn_)
        return;  // not connected yet; there is nothing to send on

    const int payload_len = next_send_index_;
    const uint16_t frame_len = static_cast<uint16_t>(2 + payload_len);

    // Build the frame directly inside the string (one allocation). The header
    // is copied in with memcpy, because storing through a casted uint16_t* is
    // an aliasing and alignment hazard.
    std::string frame(frame_len, static_cast<char>('a' + payload_len % 26));
    std::memcpy(&frame[0], &frame_len, sizeof frame_len);
    conn_->send(std::move(frame));

    next_send_index_++;
    if (next_send_index_ > MAX_INDEX)
    {
        return;
    }

    setTimer(INTERNAL_MSEC);
}

/// Message callback. It consumes every complete frame in @p buf and verifies it.
///
/// Each frame must carry the next expected index, i.e. its payload length.
/// Its payload must be filled with 'a' + index % 26. Any mismatch prints
/// "error" and terminates the process.
///
/// After MAX_INDEX has been received, the owner is notified once and any
/// remaining bytes are left in the buffer.
void Session::onMessage(const TcpConnectionPtr& conn, Buffer* buf)
{
    if (!conn || !conn->connected())
        return;

    while (1) {
        // Wait until the full 2-byte length header is available.
        if (buf->readableBytes() < 2)
            return;

        // memcpy rather than dereferencing a casted uint16_t*: the buffer
        // makes no alignment guarantee for the header bytes.
        uint16_t frame_len = 0;
        std::memcpy(&frame_len, buf->peek(), sizeof frame_len);
        if (frame_len > buf->readableBytes())
            return;  // frame not fully received yet

        const char* data = buf->peek() + 2;
        const int index = frame_len - 2;
        if (next_recv_index_ != index)
        {
            std::cout << "error" << std::endl;
            assert(0);
            exit(-1);
        }

        const char expected = static_cast<char>(index % 26 + 'a');
        for (int i = 0; i < index; ++i)
        {
            if (data[i] != expected)
            {
                std::cout << "error" << std::endl;
                assert(0);
                exit(-1);
            }
        }

        if (next_recv_index_ % 1000 == 0)
            std::cout << static_cast<const void*>(this) << "recv " << next_recv_index_ << std::endl;
        buf->retrieve(frame_len);
        next_recv_index_++;

        if (next_recv_index_ > MAX_INDEX)
        {
            std::cout << "session " << static_cast<const void*>(this) << "recv complete " << std::endl;
            owner_->increaseComplete();
            return;
        }
    }
}

/// Entry point: client <host_ip> <port> <threads> <sessions> <time>.
///
/// Returns 1 on a usage or argument error, otherwise 0 after the run finishes.
int main(int argc, char* argv[])
{
    if (argc != 6)
    {
        fprintf(stderr, "Usage: client <host_ip> <port> <threads> ");
        fprintf(stderr, "<sessions> <time>\n");
        return 1;
    }

    // Parses a strictly positive decimal integer no larger than max_value.
    // Returns -1 for malformed, overflowing, zero/negative or too-large input.
    // atoi would silently turn all of these into 0 or garbage.
    auto parse_positive = [](const char* text, long max_value) -> long {
        char* end = nullptr;
        errno = 0;
        const long value = std::strtol(text, &end, 10);
        if (end == text || *end != '\0' || errno == ERANGE || value <= 0 || value > max_value)
            return -1;
        return value;
    };

    const long port = parse_positive(argv[2], 65535);
    const long threadCount = parse_positive(argv[3], INT_MAX);
    const long sessionCount = parse_positive(argv[4], INT_MAX);
    const long timeout = parse_positive(argv[5], INT_MAX);
    if (port < 0 || threadCount < 0 || sessionCount < 0 || timeout < 0)
    {
        fprintf(stderr, "port, threads, sessions and time must be positive integers\n");
        return 1;
    }

    easyasio::base::Logger::instance().setLevel(easyasio::base::Logger::LOGLEVEL_TRACE);
    const char* ip = argv[1];

    ClientMng client(ip, static_cast<int>(port), static_cast<int>(sessionCount),
                     static_cast<int>(timeout), static_cast<int>(threadCount));
    client.run();

    return 0;
}

